import sys
import pickle
import pandas as pd
import statistics
import numpy as np
sys.path.append(sys.argv[3])
from signal_udfs import unpickle
from signal_udfs import plot_all_signals
from matplotlib import pyplot as plt
import seaborn as sns
import warnings

# Suppress every warning for the remainder of the run.
warnings.filterwarnings('ignore')

# Positional command-line arguments used by this script:
#   argv[1] -> input_handle       (not referenced by the code in this file)
#   argv[2] -> output_handle      (last 4 chars stripped to build output file names)
#   argv[3] -> directory appended to sys.path in the import block (for signal_udfs)
#   argv[4] -> z_traceout_handle  (last 4 chars stripped to name per-event plots)
#   argv[5], argv[6], argv[7] -> pickle files loaded below
input_handle = sys.argv[1]
output_handle = sys.argv[2]
z_traceout_handle = sys.argv[4]

# Load the three pickled inputs in argv order. `unpickle` is called with a
# one-element list of paths and the first element of its result is kept.
# NOTE(review): unpickling can execute arbitrary code -- only pass trusted files.
# NOTE(review): base_epoch is not referenced later in this file.
trace_dict, dff0_dict, base_epoch = (
    unpickle([sys.argv[arg_pos]])[0] for arg_pos in (5, 6, 7)
)

# Per-event baseline statistics. trace_dict['Baseline Trace'] holds one sample
# sequence per event, so the DataFrame built from it has one row per event
# (rows of unequal length are NaN-padded, and skipna ignores the padding).
# Reducing along axis=1 yields one mean / sample stdev per event.
baseline_frame = pd.DataFrame.from_dict(trace_dict['Baseline Trace'])
base_means = baseline_frame.mean(axis=1, skipna=True).to_list()
base_stdevs = baseline_frame.std(axis=1, skipna=True).to_list()

# Session-wide mean and sample standard deviation of the full DF/F0 signal,
# used to z-score the whole recording.
dff0_signal = dff0_dict['DF/F0']
dff0_mean = statistics.mean(dff0_signal)
dff0_stdev = statistics.stdev(dff0_signal)

# Fresh, empty per-event accumulators for the z-scored event traces that are
# filled in later in the script (one list per key).
names = ['Event Baseline Z-Score', 'Event Z-Score', 'Event Full Z-Score']
trace_dict.update({name: [] for name in names})

# Z-score the full DF/F0 signal against its own session-wide mean and stdev.
dff0_zscore = [(sample - dff0_mean) / dff0_stdev for sample in dff0_dict['DF/F0']]

# Detect local maxima of the z-scored DF/F0 trace.
#
# Sample i is a peak when the signal rises into it and falls out of it, i.e.
# the forward differences satisfy step[i-1] > 0 and step[i] < 0.
# The previous version looked for a sign change in np.gradient (central
# differences), which has two problems:
#   - It reports the sample *before* the apex whenever the apex's right
#     neighbour is lower than its left one (e.g. [0, 2, 5, 1, 0] -> index 1).
#   - It misses an apex whose two neighbours are equal, because the central
#     difference there is exactly 0 (e.g. [0, 1, 5, 1, 0] -> no peak).
# Flat-topped peaks (two or more equal apex samples) are still not reported.
step = np.diff(np.asarray(dff0_zscore, dtype=float))
peak_idx = np.where((step[:-1] > 0) & (step[1:] < 0))[0] + 1

# Timestamps and z-score values of every detected peak, in index order.
peak_ts = [dff0_dict['Time'][idx] for idx in peak_idx]
peak_z = [dff0_zscore[idx] for idx in peak_idx]

# Keep only peaks that are
#   (a) at least MIN_PEAK_Z standard deviations above the session mean, and
#   (b) isolated: at least MIN_PEAK_GAP_S seconds from the neighbouring
#       detected peak on each side that exists.
# The first peak has no left neighbour and the last has no right neighbour,
# so only the existing side is checked for them.
# The previous version had two defects:
#   - It never checked the first peak against its right neighbour.
#   - For the first peak it compared against peak_ts[-1] (the *last* peak)
#     through negative-index wrap-around.
MIN_PEAK_Z = 0.5
MIN_PEAK_GAP_S = 0.27

keep_peak_ts = []
keep_peak_z = []
keep_peak_idx = []

last_pos = len(peak_z) - 1
for pos, z in enumerate(peak_z):
    # Written as ">=" so a NaN z-score is rejected, as before.
    is_tall = z >= MIN_PEAK_Z
    far_from_prev = pos == 0 or peak_ts[pos] - peak_ts[pos - 1] >= MIN_PEAK_GAP_S
    far_from_next = pos == last_pos or peak_ts[pos + 1] - peak_ts[pos] >= MIN_PEAK_GAP_S
    if is_tall and far_from_prev and far_from_next:
        keep_peak_idx.append(peak_idx[pos])
        keep_peak_ts.append(peak_ts[pos])
        keep_peak_z.append(z)
                        
# Bundle the retained peaks into one dict (index, time, z-score), keyed in
# this order; it is pickled at the end of the script.
names = ['Peak Index', 'Peak Time', 'Peak Z']
peak_dict = dict(zip(names, (keep_peak_idx, keep_peak_ts, keep_peak_z)))

# Plot the z-scored DF/F0 trace with the retained peaks overlaid as markers.
# The figure is saved as '<output_handle minus its last 4 chars>_transients.png'.
# All peaks are drawn as a single marker-only series. The previous per-peak
# plt.plot calls created one artist per peak and gave every marker a
# different colour from the colour cycle.
peak_positions = np.asarray(keep_peak_idx, dtype=int)  # int dtype so an empty list still indexes
plt.plot(dff0_dict['Time'], dff0_zscore)
plt.plot(np.asarray(dff0_dict['Time'])[peak_positions],
         np.asarray(dff0_zscore)[peak_positions],
         linestyle='none', marker='o')
plt.xlabel('Time (s)')
plt.ylabel('DF/F0 (Z-Score)')
plt.savefig(output_handle[:-4] + '_transients.png')
plt.close()


# Z-score every event against its own baseline window: each sample becomes
# (sample - baseline mean) / baseline stdev, using that event's entries in
# base_means / base_stdevs. For each event, three parallel lists are appended:
#   - the baseline samples only,
#   - the event samples only,
#   - the baseline samples followed by the event samples.
for event_pos, _event_name in enumerate(trace_dict['Name']):
    baseline_scores = [
        (value - base_means[event_pos]) / base_stdevs[event_pos]
        for value in trace_dict['Baseline Trace'][event_pos]
    ]
    event_scores = [
        (value - base_means[event_pos]) / base_stdevs[event_pos]
        for value in trace_dict['Event Trace'][event_pos]
    ]
    trace_dict['Event Baseline Z-Score'].append(baseline_scores)
    trace_dict['Event Z-Score'].append(event_scores)
    trace_dict['Event Full Z-Score'].append(baseline_scores + event_scores)
        
# Save one PNG per event: its baseline+event z-score trace plotted against the
# event's 'Full Time Norm' axis. File name:
# '<z_traceout_handle minus its last 4 chars>_<event name>_4s_peri-event_zscore_trace.png'.
zscore_stem = z_traceout_handle[:-4]
for pos, event in enumerate(trace_dict['Name']):
    time_axis = trace_dict['Full Time Norm'][pos]
    z_trace = trace_dict['Event Full Z-Score'][pos]
    plt.plot(time_axis, z_trace)
    plt.xlabel('Time (s)')
    plt.ylabel('DF/F0 (Z-Score)')
    plt.title(event + ' Event')
    plt.savefig(zscore_stem + '_' + event + '_4s_peri-event_zscore_trace.png')
    plt.close()

# Write every event's full (baseline + event) z-score trace to one CSV.
#   - The first column, 'Time (s)', is the longest 'Full Time Norm' axis
#     (the first one wins on ties).
#   - Then comes one column per event, named from trace_dict['Name'].
#   - Shorter columns are NaN-padded by pandas.
# The previous version also built an unused list of sample intervals
# (temp_time) and an unused counter (current_ts); that dead work is removed.
longest_line = max(trace_dict['Full Time Norm'], key=len, default=[])

col_names = ['Time (s)'] + list(trace_dict['Name'])

df_time = pd.DataFrame(longest_line)
# One row per event from the list of traces, transposed to one column per event.
df_trace = pd.DataFrame.from_dict(trace_dict['Event Full Z-Score']).T
df_trace_out = pd.concat([df_time, df_trace], axis=1, ignore_index=True)
df_trace_out.columns = col_names
df_trace_out.to_csv(output_handle[:-4] + '_4s_peri-event_zscore_traces.csv', index=False)


# Persist the updated trace dict and the peak dict to fixed file names in the
# current working directory. The `with` blocks guarantee each file is flushed
# and closed; the previous bare open(...) calls leaked both handles.
with open('trace_dict_4s_peri.pkl', 'wb') as trace_file:
    pickle.dump(trace_dict, trace_file)
with open('peak_dict.pkl', 'wb') as peak_file:
    pickle.dump(peak_dict, peak_file)
